Model Export to liteRT #2405
Conversation
This reverts commit 62d2484.
This reverts commit de830b1.
Refactored exporter and registry logic for better type safety and error handling. Improved input signature methods in config classes by extracting sequence length logic. Enhanced LiteRT exporter with clearer verbose handling and stricter error reporting. Registry now conditionally registers LiteRT exporter and extends export method only if dependencies are available.
Summary of Changes
Hello @pctablet505, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a comprehensive and extensible framework for exporting Keras-Hub models to various formats, with an initial focus on LiteRT. The system is designed to seamlessly integrate with Keras-Hub's model architecture, particularly by addressing the unique challenge of handling dictionary-based model inputs during the export process. This enhancement significantly improves the deployability of Keras-Hub models by providing a standardized and robust export pipeline, alongside crucial compatibility fixes for TensorFlow's SavedModel/TFLite export mechanisms.
Highlights
- New Model Export Framework: Introduced a new, extensible framework for exporting Keras-Hub models, designed to support various formats and model types.
- LiteRT Export Support: Added specific support for exporting Keras-Hub models to the LiteRT format, verified for models like gemma3, llama3.2, and gpt2.
- Registry-Based Configuration: Implemented an `ExporterRegistry` to manage and retrieve appropriate exporter configurations and exporters based on model type and target format.
- Input Handling for Keras-Hub Models: Developed a `KerasHubModelWrapper` to seamlessly convert Keras-Hub's dictionary-based inputs to the list-based inputs expected by the underlying Keras LiteRT exporter.
- TensorFlow Export Compatibility: Added compatibility shims (`_get_save_spec` and `_trackable_children`) to Keras-Hub `Backbone` models to ensure proper functioning with TensorFlow's SavedModel and TFLite export utilities.
- Automated Export Method Extension: The `Task` class in Keras-Hub models is now automatically extended with an `export` method, simplifying the model export process for users.
Code Review
This pull request introduces a significant new feature: model exporting to liteRT. The implementation is well-structured, using a modular and extensible registry pattern. However, there are several areas that require attention. The most critical issue is the complete absence of tests for the new export functionality, which is a direct violation of the repository's style guide stating that testing is non-negotiable. Additionally, I've identified a critical bug in the error handling logic within the lite_rt.py exporter that includes unreachable code. There are also several violations of the style guide regarding the use of type hints in function signatures across all new files. I've provided specific comments and suggestions to address these points, which should help improve the robustness, maintainability, and compliance of this new feature.
keras_hub/src/export/configs.py
Outdated
```python
def _get_sequence_length(self) -> int:
    """Get sequence length from model or use default."""
    if hasattr(self.model, 'preprocessor') and self.model.preprocessor:
        return getattr(
            self.model.preprocessor,
            'sequence_length',
            self.DEFAULT_SEQUENCE_LENGTH,
        )
    return self.DEFAULT_SEQUENCE_LENGTH
```
The _get_sequence_length method is duplicated across CausalLMExporterConfig, TextClassifierExporterConfig, Seq2SeqLMExporterConfig, and TextModelExporterConfig. To improve maintainability and reduce code duplication, this method should be moved to the base class KerasHubExporterConfig in keras_hub/src/export/base.py.
We have different kinds of models in keras-hub: some deal with text and have a sequence length, while other models don't, so we currently can't generalize it for all models.
This won't be a good choice for now, as not all models have a `sequence_length` parameter, for example image segmentation models.
Introduces the keras_hub.api.export submodule and updates the main API to expose it. The new export module imports various exporter configs and functions from the internal export package, making them available through the public API.
Added ImageClassifierExporterConfig, ImageSegmenterExporterConfig, and ObjectDetectorExporterConfig to the export API. Improved input shape inference and dummy input generation for image-related exporter configs. Refactored LiteRTExporter to better handle model type checks and input signature logic, with improved error handling for input mapping.
Moved the 'import keras' statement to the top of the module and removed redundant local imports within class methods. This improves code clarity and avoids repeated imports.
Deleted the debug_object_detection.py script, which was used for testing object detection model outputs and export issues. This cleanup removes unused debugging code from the repository.
Renames all references of 'LiteRT' to 'Litert' across the codebase, including file names, class names, and function names. Updates exporter registry and API imports to use the new 'litert' naming. Also improves image model exporter configs to dynamically determine input dtype from the model, enhancing flexibility for different input types. Adds support for ImageSegmenter model type detection in the exporter registry.
Refactored InputSpec definitions in exporter configs for improved readability by placing each argument on a separate line. Updated import path in litert.py to import from keras.src.export.litert instead of keras.src.export.litert_exporter.
@pctablet505 can you update the Colab to use the changes from this PR? And keep the demo short: load a model, export it, then reload and verify the numerics.
Also, if you have addressed the Gemini comments, please respond to them. Is this PR dependent on the PR in the Keras repo, keras-team/keras#21674?
Simplifies and unifies input signature and dummy input generation for text and image models by removing redundant helper methods and centralizing logic. Updates model building in KerasHubExporter to use input signatures and improves error handling. Refactors LiteRT exporter to use the new parameterized input signature and model building approach, reducing code duplication and improving maintainability.
Removed redundant registry initialization in export_model and clarified model building comments in KerasHubExporter. Switched to using warnings.warn for import errors in models/__init__.py instead of print statements for better error reporting.
@gemini-code-assist please review it again
Code Review
This is a significant and valuable addition to KerasHub, enabling model exportation to LiteRT. The overall structure is well-designed, following a modular and extensible pattern similar to Hugging Face's Optimum.
I've reviewed this pull request, focusing on correctness, maintainability, and adherence to the repository's style guide. The style guide was instrumental in this review, particularly the sections on Testing and Validation (lines 401-496) and Docstrings and Type Hinting (lines 364-398).
My main findings are:
- Critical Issue: Missing Tests
  - The pull request introduces substantial new functionality for model exporting across several new files (`base.py`, `configs.py`, `litert.py`, `registry.py`), but it lacks corresponding tests.
  - The repository style guide is explicit that "Testing is a non-negotiable part of every contribution" (line 403) and "Every .py file containing logic...must have a corresponding `_test.py` file" (line 406).
  - Please add comprehensive unit tests for the new export logic, covering different model types, configurations, and edge cases. This is crucial to ensure the robustness and correctness of this feature.
- Other Findings
  - I've also left several inline comments regarding a bug in model type detection, incorrect dtype handling, and violations of the docstring style guide. Please address these to improve code quality and consistency.
Refined dtype extraction logic in image and object model exporter configs to better handle different dtype representations. Updated LiteRT exporter to use Keras io_utils for progress messages and improved verbose flag handling. Added ObjectDetector and ImageSegmenter to export registry model type checks. Enhanced docstrings for clarity and consistency in base exporter classes.
Eliminates the get_dummy_inputs methods from KerasHubExporterConfig and its subclasses. Model building now relies solely on shape-based initialization, simplifying the export process and removing fallback logic for dummy data.
@gemini-code-assist please review the code changes, and see if anything can be improved, optimized, or removed if not necessary.
Code Review
This pull request introduces a well-structured and extensible export functionality for Keras-Hub models, with an initial focus on LiteRT. The design, which utilizes configurations, exporters, and a registry, is solid. The changes also include necessary compatibility shims in the Backbone class to support TensorFlow export. My review focuses on enhancing robustness and ensuring adherence to the repository's style guide. I have identified a few areas for improvement, such as replacing fragile error-message parsing with more direct introspection, addressing the use of private APIs, and ensuring docstrings fully conform to the specified style.
keras_hub/src/export/litert.py
Outdated
```python
except ValueError as e:
    error_msg = str(e)
    # If that fails, try direct tensor input (positional format)
    if (
        "doesn't match the expected structure" in error_msg
        and "Expected: keras_tensor" in error_msg
    ):
        # The model expects a single tensor, not a dictionary
        if len(inputs) == 1:
            return self.keras_hub_model(
                inputs[0], training=training, mask=mask
            )
        else:
            # Multiple inputs - try as positional arguments
            return self.keras_hub_model(
                *inputs, training=training, mask=mask
            )
    elif "Missing data for input" in error_msg:
        # Extract the actual expected input names from the error
        if "Expected the following keys:" in error_msg:
            # Parse the expected keys from error message
            start = error_msg.find("Expected the following keys: [")
            if start != -1:
                start += len("Expected the following keys: [")
                end = error_msg.find("]", start)
                if end != -1:
                    keys_str = error_msg[start:end]
                    actual_input_names = [
                        k.strip().strip("'\"")
                        for k in keys_str.split(",")
                    ]

                    # Map inputs to actual expected names
                    input_dict = {}
                    for i, actual_name in enumerate(actual_input_names):
                        if i < len(inputs):
                            input_dict[actual_name] = inputs[i]

                    return self.keras_hub_model(
                        input_dict, training=training, mask=mask
                    )

    # If we still can't figure it out, re-raise the original error
    raise
```
The error handling logic in the call method of KerasHubModelWrapper relies on parsing exception messages (e.g., "doesn't match the expected structure", "Missing data for input"). This is a brittle approach, as the error messages from underlying libraries like Keras or TensorFlow can change in future versions, which would break this export functionality.
A more robust approach would be to programmatically inspect the model's input signature instead of reacting to errors. Here are some suggestions:

- For functional models, you can use `self.keras_hub_model.input_names` to get the expected keys.
- For subclassed models, you could use Python's `inspect.signature(self.keras_hub_model.call)` to determine the expected arguments.

By proactively determining the input structure, you can avoid this fragile error parsing and make the exporter more resilient to changes in dependencies.
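A minimal sketch of the introspection-based alternative. `infer_call_arg_names` is a hypothetical helper (not part of this PR), and `FakeTaskModel` stands in for a subclassed keras-hub model:

```python
import inspect


def infer_call_arg_names(model):
    """Return the data-bearing argument names of ``model.call``.

    Instead of parsing exception text, inspect the signature of the
    model's ``call`` method (suitable for subclassed models).
    """
    params = inspect.signature(model.call).parameters
    return [
        name
        for name, param in params.items()
        if name not in ("self", "training", "mask")
        and param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
    ]


class FakeTaskModel:
    # Stand-in for a subclassed Keras-Hub model whose ``call``
    # accepts a dict of named tensors.
    def call(self, inputs, training=None, mask=None):
        return inputs


print(infer_call_arg_names(FakeTaskModel()))  # ['inputs']
```

The wrapper could then decide up front whether to pass a dict, a single tensor, or positional arguments, rather than retrying after a `ValueError`.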
```python
# Import _DictWrapper safely
try:
    from tensorflow.python.trackable.data_structures import _DictWrapper
except ImportError:
    return children
```
This method imports _DictWrapper from a private TensorFlow module (tensorflow.python.trackable.data_structures). Relying on private APIs is risky because they are not guaranteed to be stable and can be changed or removed without notice in future TensorFlow releases, which could break the export functionality.
While I understand this might be a necessary workaround for current issues with TensorFlow export, it would be ideal to find a solution that uses public APIs if possible. If no public API is available, consider adding a comment here warning future developers about the dependency on a private API and potentially pinning the TensorFlow version more strictly if this is critical.
Refactored all references and class names from LitertExporter to LiteRTExporter for consistency with Keras naming conventions. This affects imports, class definitions, and usage throughout the export modules.
/gemini review
Code Review
This pull request introduces a comprehensive model export functionality to liteRT, which is a significant and valuable addition. The implementation includes base classes for exporters and configurations, a registry for managing them, and specific logic for liteRT export. The approach of extending Backbone and Task classes to integrate this functionality is well-thought-out. My review focuses on enhancing the robustness and maintainability of the new code. Specifically, I've provided suggestions to make model type detection less brittle, correct a potential data type issue with padding_mask, reduce code duplication in configuration classes, and improve the input handling logic in the model wrapper to avoid fragile error-parsing.
keras_hub/src/export/litert.py
Outdated
```python
except ValueError as e:
    error_msg = str(e)
    # If that fails, try direct tensor input (positional format)
    if (
        "doesn't match the expected structure" in error_msg
        and "Expected: keras_tensor" in error_msg
    ):
        # The model expects a single tensor, not a dictionary
        if len(inputs) == 1:
            return self.keras_hub_model(
                inputs[0], training=training, mask=mask
            )
        else:
            # Multiple inputs - try as positional arguments
            return self.keras_hub_model(
                *inputs, training=training, mask=mask
            )
    elif "Missing data for input" in error_msg:
        # Extract the actual expected input names from the error
        if "Expected the following keys:" in error_msg:
            # Parse the expected keys from error message
            start = error_msg.find("Expected the following keys: [")
            if start != -1:
                start += len("Expected the following keys: [")
                end = error_msg.find("]", start)
                if end != -1:
                    keys_str = error_msg[start:end]
                    actual_input_names = [
                        k.strip().strip("'\"")
                        for k in keys_str.split(",")
                    ]

                    # Map inputs to actual expected names
                    input_dict = {}
                    for i, actual_name in enumerate(actual_input_names):
                        if i < len(inputs):
                            input_dict[actual_name] = inputs[i]

                    return self.keras_hub_model(
                        input_dict, training=training, mask=mask
                    )

    # If we still can't figure it out, re-raise the original error
    raise
```
The call method of the KerasHubModelWrapper relies on parsing error message strings to determine the correct input format for the underlying model. This is very fragile and can easily break with changes in Keras error messages. A more robust approach would be to define the expected input format (e.g., 'dict', 'list', 'tensor') as a property on the KerasHubExporterConfig. The wrapper could then use this property to format the inputs correctly, avoiding the current trial-and-error approach.
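A sketch of that property-based approach. The `INPUT_FORMAT` attribute and `format_inputs` helper are hypothetical names illustrating the suggestion, not code from this PR:

```python
class CausalLMConfigSketch:
    # Each config declares how the wrapped model expects its inputs,
    # so the wrapper never has to guess from exception messages.
    INPUT_FORMAT = "dict"  # one of: "dict", "list", "tensor"
    EXPECTED_INPUTS = ["token_ids", "padding_mask"]


def format_inputs(config, tensors):
    """Arrange positional ``tensors`` into the structure the model expects."""
    if config.INPUT_FORMAT == "dict":
        return dict(zip(config.EXPECTED_INPUTS, tensors))
    if config.INPUT_FORMAT == "tensor":
        return tensors[0]
    return list(tensors)


config = CausalLMConfigSketch()
print(format_inputs(config, ["ids", "mask"]))
# {'token_ids': 'ids', 'padding_mask': 'mask'}
```

With this in place, `KerasHubModelWrapper.call` becomes a single dispatch on `INPUT_FORMAT` instead of a try/except cascade.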
keras_hub/src/export/base.py
Outdated
```python
model_class_name = model.__class__.__name__

if CausalLM and isinstance(model, CausalLM):
    return "causal_lm"
elif "TextClassifier" in model_class_name:
    return "text_classifier"
elif Seq2SeqLM and isinstance(model, Seq2SeqLM):
    return "seq2seq_lm"
elif "ImageClassifier" in model_class_name:
    return "image_classifier"
elif ObjectDetector and isinstance(model, ObjectDetector):
    return "object_detector"
elif "ObjectDetector" in model_class_name:
    return "object_detector"
elif ImageSegmenter and isinstance(model, ImageSegmenter):
    return "image_segmenter"
elif "ImageSegmenter" in model_class_name:
    return "image_segmenter"
else:
    # Default to text model for generic Keras-Hub models
    return "text_model"
```
The model type detection logic is brittle as it mixes isinstance checks with string-based checks on the class name. It would be more robust to consistently use isinstance with the base class for each model type. This can be achieved by importing TextClassifier and ImageClassifier and using isinstance for them, while also cleaning up the redundant checks for ObjectDetector and ImageSegmenter.
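One way to unify the checks is a single ordered table of `isinstance` tests. This sketch uses stand-in classes; the real code would import the keras-hub task base classes, and the ordering assumption (subclasses before parents) would need checking against the actual class hierarchy:

```python
# Stand-ins for the keras-hub task base classes.
class CausalLM: ...
class Seq2SeqLM(CausalLM): ...  # assumed subclass relationship
class TextClassifier: ...
class ImageClassifier: ...

# Most specific classes first, so a Seq2SeqLM is not reported
# as a plain CausalLM.
_MODEL_TYPE_TABLE = [
    (Seq2SeqLM, "seq2seq_lm"),
    (CausalLM, "causal_lm"),
    (TextClassifier, "text_classifier"),
    (ImageClassifier, "image_classifier"),
]


def detect_model_type(model, default="text_model"):
    for cls, type_name in _MODEL_TYPE_TABLE:
        if isinstance(model, cls):
            return type_name
    return default


print(detect_model_type(Seq2SeqLM()))  # 'seq2seq_lm'
print(detect_model_type(object()))     # 'text_model'
```

This removes the duplicated string checks entirely and makes adding a new task type a one-line change.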
```python
def get_input_signature(self, image_size=None):
    """Get input signature for image classifier models.

    Args:
        image_size: Optional image size. If None, inferred from model.

    Returns:
        Dict[str, Any]: Dictionary mapping input names to their
        specifications
    """
    if image_size is None:
        # Get from preprocessor
        if hasattr(self.model, "preprocessor") and self.model.preprocessor:
            if hasattr(self.model.preprocessor, "image_size"):
                image_size = self.model.preprocessor.image_size

        # Try to infer from model inputs
        if (
            image_size is None
            and hasattr(self.model, "inputs")
            and self.model.inputs
        ):
            input_shape = self.model.inputs[0].shape
            if (
                len(input_shape) == 4
                and input_shape[1] is not None
                and input_shape[2] is not None
            ):
                image_size = (input_shape[1], input_shape[2])

        if image_size is None:
            raise ValueError(
                "Could not determine image size from model. "
                "Model should have a preprocessor with image_size "
                "attribute, or model inputs should have concrete shapes."
            )

    if isinstance(image_size, int):
        image_size = (image_size, image_size)

    # Get input dtype
    dtype = "float32"
    if hasattr(self.model, "inputs") and self.model.inputs:
        model_dtype = self.model.inputs[0].dtype
        dtype = (
            model_dtype.name
            if hasattr(model_dtype, "name")
            else model_dtype
        )

    return {
        "images": keras.layers.InputSpec(
            shape=(None, *image_size, 3),
            dtype=dtype,
            name="images",
        ),
    }
```
The logic for inferring image_size and dtype is duplicated across ImageClassifierExporterConfig, ObjectDetectorExporterConfig, and ImageSegmenterExporterConfig. This duplicated code should be refactored into a private helper function within this module to improve maintainability and reduce redundancy.
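Such a module-level helper might look like this. A sketch only: `resolve_image_size` is a hypothetical name, and the stand-in objects below replace a real Keras model for illustration:

```python
def resolve_image_size(model):
    """Infer ``(height, width)`` from a model's preprocessor or inputs."""
    preprocessor = getattr(model, "preprocessor", None)
    image_size = getattr(preprocessor, "image_size", None)
    if image_size is None:
        inputs = getattr(model, "inputs", None)
        if inputs:
            shape = inputs[0].shape
            if len(shape) == 4 and shape[1] is not None and shape[2] is not None:
                image_size = (shape[1], shape[2])
    if image_size is None:
        raise ValueError(
            "Could not determine image size from model. Model should have "
            "a preprocessor with an image_size attribute, or model inputs "
            "should have concrete shapes."
        )
    if isinstance(image_size, int):
        image_size = (image_size, image_size)
    return image_size


# Minimal stand-ins to exercise the helper.
class FakePreprocessor:
    image_size = 224


class FakeModel:
    preprocessor = FakePreprocessor()


print(resolve_image_size(FakeModel()))  # (224, 224)
```

Each image config would then call the helper and only build its own `InputSpec` dictionary.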
@pctablet505 can you please clean up the Colab demo? It is massive!
The test coverage for the changes in this PR is not sufficient. Please add more test coverage.
```python
"""Exporter configuration for Causal Language Models (GPT, LLaMA, etc.)."""

MODEL_TYPE = "causal_lm"
EXPECTED_INPUTS = ["token_ids", "padding_mask"]
```
This hardcoding may not work with all causal models.
Changed exporter registry and config registration to use model classes instead of string type names for improved type safety and clarity. Updated input signature methods to use isinstance checks and standardized padding_mask dtype to int32. Enhanced LiteRTExporter to dynamically determine input signature parameters based on model type and preprocessor attributes.
Replaces the try-except block for importing keras with a direct import, assuming keras is always available. Simplifies the code and removes the KERAS_AVAILABLE flag.
Introduce new test modules for export base classes, configuration classes, LiteRT export functionality, registry logic, and production model export verification. Also update TensorFlow CUDA requirements to include ai-edge-litert for LiteRT export support.
Replaces the previous wrapper with type-specific adapter classes for text and image models in the LiteRT exporter, improving input conversion logic and maintainability. Also updates docstrings and return type annotations for consistency across exporter config classes.
Code Review
This pull request introduces a robust and well-structured framework for exporting Keras Hub models to the LiteRT format. The design, which utilizes a registry, exporter configurations, and model-specific adapters, is commendable for its extensibility and clear separation of concerns. The addition of comprehensive unit and integration tests, including numerical verification for production models, significantly increases confidence in this new feature.
My review focuses on improving maintainability by reducing code duplication in the configuration classes, enhancing the flexibility of the exporter registry, and fixing a minor bug in the tests. I've also suggested opportunities to simplify some of the implementation details. Overall, this is a high-quality contribution that adds significant value to Keras Hub.
```python
for model_class in [
    CausalLM,
    TextClassifier,
    Seq2SeqLM,
    ImageClassifier,
    ObjectDetector,
    ImageSegmenter,
]:
    if isinstance(model, model_class):
        if model_class not in cls._configs:
            raise ValueError(
                f"No configuration found for model type: "
                f"{model_class.__name__}"
            )
        config_class = cls._configs[model_class]
        return config_class(model)
```
The get_config_for_model method uses a hardcoded list of model classes for type detection. This approach is not easily extensible and requires this list to be manually updated whenever a new model type is added, which is error-prone.
A more robust and maintainable approach would be to leverage the registry itself. You can iterate over cls._configs.items() to perform the isinstance check against the registered model classes. This would make the registry truly plug-and-play, as simply registering a new model-config pair would be sufficient for it to be discovered.
```python
for model_class, config_class in cls._configs.items():
    if isinstance(model, model_class):
        return config_class(model)
```
keras_hub/src/export/configs.py
Outdated
```python
def get_input_signature(self, sequence_length=None):
    """Get input signature for causal LM models.

    Args:
        sequence_length: Optional sequence length. If None, uses default.

    Returns:
        `dict`. Dictionary mapping input names to their specifications
    """
    if sequence_length is None:
        # Get from preprocessor or use default
        if hasattr(self.model, "preprocessor") and self.model.preprocessor:
            sequence_length = getattr(
                self.model.preprocessor,
                "sequence_length",
                self.DEFAULT_SEQUENCE_LENGTH,
            )
        else:
            sequence_length = self.DEFAULT_SEQUENCE_LENGTH

    return {
        "token_ids": keras.layers.InputSpec(
            shape=(None, sequence_length), dtype="int32", name="token_ids"
        ),
        "padding_mask": keras.layers.InputSpec(
            shape=(None, sequence_length),
            dtype="int32",
            name="padding_mask",
        ),
    }
```
There is some code duplication in the get_input_signature method across CausalLMExporterConfig, TextClassifierExporterConfig, and TextModelExporterConfig. To improve maintainability, consider creating a shared base class, for example TextExporterConfig, that implements this common logic. The specific text model configs can then inherit from it.
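A sketch of such a shared base. The `TextExporterConfig` name, the `resolve_sequence_length` helper, and the default of 512 are assumptions for illustration, not code from this PR:

```python
DEFAULT_SEQUENCE_LENGTH = 512  # assumed default; the PR's value may differ


class TextExporterConfig:
    """Shared logic for text-model exporter configs."""

    def __init__(self, model):
        self.model = model

    def resolve_sequence_length(self, sequence_length=None):
        # Explicit argument wins; otherwise fall back to the
        # preprocessor's sequence_length, then the default.
        if sequence_length is not None:
            return sequence_length
        preprocessor = getattr(self.model, "preprocessor", None)
        return getattr(
            preprocessor, "sequence_length", DEFAULT_SEQUENCE_LENGTH
        )


# Minimal stand-ins to exercise the base class.
class FakePreprocessor:
    sequence_length = 128


class FakeModel:
    preprocessor = FakePreprocessor()


config = TextExporterConfig(FakeModel())
print(config.resolve_sequence_length())    # 128
print(config.resolve_sequence_length(64))  # 64
```

`CausalLMExporterConfig` and friends would then only construct their own `InputSpec` dictionaries from the resolved length.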
```python
def get_input_signature(self, image_size=None):
    """Get input signature for image classifier models.

    Args:
        image_size: Optional image size. If None, inferred from model.

    Returns:
        `dict`. Dictionary mapping input names to their specifications
    """
    if image_size is None:
        # Get from preprocessor
        if hasattr(self.model, "preprocessor") and self.model.preprocessor:
            if hasattr(self.model.preprocessor, "image_size"):
                image_size = self.model.preprocessor.image_size

        # Try to infer from model inputs
        if (
            image_size is None
            and hasattr(self.model, "inputs")
            and self.model.inputs
        ):
            input_shape = self.model.inputs[0].shape
            if (
                len(input_shape) == 4
                and input_shape[1] is not None
                and input_shape[2] is not None
            ):
                image_size = (input_shape[1], input_shape[2])

        if image_size is None:
            raise ValueError(
                "Could not determine image size from model. "
                "Model should have a preprocessor with image_size "
                "attribute, or model inputs should have concrete shapes."
            )

    if isinstance(image_size, int):
        image_size = (image_size, image_size)

    # Get input dtype
    dtype = "float32"
    if hasattr(self.model, "inputs") and self.model.inputs:
        model_dtype = self.model.inputs[0].dtype
        dtype = (
            model_dtype.name
            if hasattr(model_dtype, "name")
            else model_dtype
        )

    return {
        "images": keras.layers.InputSpec(
            shape=(None, *image_size, 3),
            dtype=dtype,
            name="images",
        ),
    }
```
The logic for get_input_signature is very similar across the image-based exporter configurations (ImageClassifierExporterConfig, ObjectDetectorExporterConfig, ImageSegmenterExporterConfig), especially the parts that determine image_size and dtype. To avoid code duplication, you could extract this common logic into a base class like ImageExporterConfig. The child classes would then only need to define their specific EXPECTED_INPUTS and construct the final signature dictionary.
```python
# Determine the parameter to pass based on model type using isinstance
is_text_model = isinstance(
    self.model, (CausalLM, TextClassifier, Seq2SeqLM)
)
is_image_model = isinstance(
    self.model, (ImageClassifier, ObjectDetector, ImageSegmenter)
)

# For text models, use sequence_length; for image models, get image_size
# from preprocessor
if is_text_model:
    param = self.max_sequence_length
elif is_image_model:
    # Get image_size from model's preprocessor
    if hasattr(self.model, "preprocessor") and hasattr(
        self.model.preprocessor, "image_size"
    ):
        param = self.model.preprocessor.image_size
    else:
        param = None  # Will use default in get_input_signature
else:
    param = None
```
The logic for determining the model type (is_text_model, is_image_model) and the param for get_input_signature is duplicated in export() (here) and _create_export_wrapper(). This could lead to inconsistencies if one is updated and the other is not. Consider extracting this logic into a private helper method, e.g., _get_model_type_and_param(), which can be called from both places to ensure consistency and reduce code duplication.
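One way to sketch such a helper, with stand-in classes in place of the real keras_hub tasks (the helper name `get_model_kind_and_param` is illustrative):

```python
# Stand-ins for the keras_hub task classes; the real helper would use
# keras_hub.models.CausalLM, keras_hub.models.ImageClassifier, etc.
class CausalLM:
    pass


class ImageClassifier:
    preprocessor = None


TEXT_TASKS = (CausalLM,)
IMAGE_TASKS = (ImageClassifier,)


def get_model_kind_and_param(model, max_sequence_length=None):
    # One place to answer "what kind of model, and which signature param?",
    # called from both export() and _create_export_wrapper().
    if isinstance(model, TEXT_TASKS):
        return "text", max_sequence_length
    if isinstance(model, IMAGE_TASKS):
        preprocessor = getattr(model, "preprocessor", None)
        return "image", getattr(preprocessor, "image_size", None)
    return "other", None


print(get_model_kind_and_param(CausalLM(), max_sequence_length=128))
# ('text', 128)
```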
    def _is_keras_hub_model(self):
        """Check if this model is a Keras-Hub model that needs special
        handling."""
        if hasattr(self, "__class__"):
            class_name = self.__class__.__name__
            module_name = self.__class__.__module__

            # Check if it's from keras_hub package
            if "keras_hub" in module_name:
                return True

            # Check if it has keras-hub specific attributes
            if hasattr(self, "preprocessor") and hasattr(self, "backbone"):
                return True

            # Check for common Keras-Hub model names
            keras_hub_model_names = [
                "CausalLM",
                "Seq2SeqLM",
                "TextClassifier",
                "ImageClassifier",
                "ObjectDetector",
                "ImageSegmenter",
            ]
            if any(name in class_name for name in keras_hub_model_names):
                return True

        return False
The _is_keras_hub_model method uses several heuristic checks to determine if a model is a Keras Hub model. Since this method is being attached to the Task class, a more direct and robust approach would be to check isinstance(self, Task). This simplifies the code and makes it less fragile to changes in module structure or class names.
def _is_keras_hub_model(self):
    """Check if this model is a Keras-Hub model that needs special
    handling."""
    return isinstance(self, Task)
@gemini-code-assist review the code, and find the remaining issues.
Code Review
This pull request introduces a comprehensive model export functionality to liteRT, which is a significant and valuable addition. The implementation is well-structured, following a registry pattern with specific configurations for different model types, which is great for extensibility. The code is accompanied by a good set of unit and integration tests.
My review focuses on improving maintainability by reducing code duplication, fixing a bug in the model type detection logic, and correcting a few issues in the tests. These changes will make the new export feature more robust and easier to maintain in the future.
        for model_class in [
            CausalLM,
            TextClassifier,
            Seq2SeqLM,
            ImageClassifier,
            ObjectDetector,
            ImageSegmenter,
        ]:
The order of model classes in this list is crucial for correct config selection, as isinstance() checks are performed sequentially. Seq2SeqLM is a subclass of CausalLM, so a Seq2SeqLM model will be incorrectly identified as a CausalLM because isinstance(model, CausalLM) will be True. This will result in the wrong exporter configuration being used.
To fix this, subclasses should always be checked before their parent classes. Please reorder the list to place Seq2SeqLM before CausalLM.
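The ordering pitfall can be shown in isolation with stand-in classes that mirror the hierarchy the review describes (`Seq2SeqLM` subclassing `CausalLM`):

```python
# Stand-ins mirroring the keras_hub hierarchy described in the review.
class CausalLM:
    pass


class Seq2SeqLM(CausalLM):
    pass


def pick_config(model, ordered_classes):
    # Sequential isinstance() dispatch: the first match wins, so a parent
    # class listed before its subclass shadows the subclass entirely.
    for cls in ordered_classes:
        if isinstance(model, cls):
            return cls.__name__
    return None


model = Seq2SeqLM()
print(pick_config(model, [CausalLM, Seq2SeqLM]))  # CausalLM (wrong config)
print(pick_config(model, [Seq2SeqLM, CausalLM]))  # Seq2SeqLM (correct)
```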
Suggested change:

        for model_class in [
            Seq2SeqLM,
            CausalLM,
            TextClassifier,
            ImageClassifier,
            ObjectDetector,
            ImageSegmenter,
        ]:
    def get_input_signature(self, image_size=None):
        """Get input signature for image classifier models.

        Args:
            image_size: Optional image size. If None, inferred from model.

        Returns:
            `dict`. Dictionary mapping input names to their specifications
        """
        if image_size is None:
            # Get from preprocessor
            if hasattr(self.model, "preprocessor") and self.model.preprocessor:
                if hasattr(self.model.preprocessor, "image_size"):
                    image_size = self.model.preprocessor.image_size

            # Try to infer from model inputs
            if (
                image_size is None
                and hasattr(self.model, "inputs")
                and self.model.inputs
            ):
                input_shape = self.model.inputs[0].shape
                if (
                    len(input_shape) == 4
                    and input_shape[1] is not None
                    and input_shape[2] is not None
                ):
                    image_size = (input_shape[1], input_shape[2])

        if image_size is None:
            raise ValueError(
                "Could not determine image size from model. "
                "Model should have a preprocessor with image_size "
                "attribute, or model inputs should have concrete shapes."
            )

        if isinstance(image_size, int):
            image_size = (image_size, image_size)

        # Get input dtype
        dtype = "float32"
        if hasattr(self.model, "inputs") and self.model.inputs:
            model_dtype = self.model.inputs[0].dtype
            dtype = (
                model_dtype.name
                if hasattr(model_dtype, "name")
                else model_dtype
            )

        return {
            "images": keras.layers.InputSpec(
                shape=(None, *image_size, 3),
                dtype=dtype,
                name="images",
            ),
        }
The logic within this get_input_signature method for determining image_size and input dtype is duplicated in ObjectDetectorExporterConfig and ImageSegmenterExporterConfig.
To improve maintainability and reduce redundancy, consider extracting this shared logic into a common base class (e.g., ImageModelExporterConfig). The individual image model configuration classes (ImageClassifierExporterConfig, ObjectDetectorExporterConfig, ImageSegmenterExporterConfig) could then inherit from this new base class.
This would centralize the logic, making future changes easier.
        # Determine the parameter to pass based on model type using isinstance
        is_text_model = isinstance(
            self.model, (CausalLM, TextClassifier, Seq2SeqLM)
        )
        is_image_model = isinstance(
            self.model, (ImageClassifier, ObjectDetector, ImageSegmenter)
        )

        # For text models, use sequence_length; for image models, get image_size
        # from preprocessor
        if is_text_model:
            param = self.max_sequence_length
        elif is_image_model:
            # Get image_size from model's preprocessor
            if hasattr(self.model, "preprocessor") and hasattr(
                self.model.preprocessor, "image_size"
            ):
                param = self.model.preprocessor.image_size
            else:
                param = None  # Will use default in get_input_signature
        else:
            param = None
The logic to determine the model type (is_text_model, is_image_model) and the corresponding param for _ensure_model_built is duplicated in _create_export_wrapper. This redundancy makes the code harder to maintain.
This logic could be refactored into a helper method or by adding properties to the KerasHubExporterConfig base class. For instance, you could add is_text_model and is_image_model properties to KerasHubExporterConfig to centralize this check.
Example of properties in KerasHubExporterConfig:
@property
def is_text_model(self):
    return isinstance(self.model, (CausalLM, TextClassifier, Seq2SeqLM))

@property
def is_image_model(self):
    return isinstance(
        self.model, (ImageClassifier, ObjectDetector, ImageSegmenter)
    )

Then you could simplify the code here to:
if self.config.is_text_model:
    param = self.max_sequence_length
elif self.config.is_image_model:
    # ...

This would remove duplication and make the code cleaner.
    def _is_keras_hub_model(self):
        """Check if this model is a Keras-Hub model that needs special
        handling."""
        if hasattr(self, "__class__"):
            class_name = self.__class__.__name__
            module_name = self.__class__.__module__

            # Check if it's from keras_hub package
            if "keras_hub" in module_name:
                return True

            # Check if it has keras-hub specific attributes
            if hasattr(self, "preprocessor") and hasattr(self, "backbone"):
                return True

            # Check for common Keras-Hub model names
            keras_hub_model_names = [
                "CausalLM",
                "Seq2SeqLM",
                "TextClassifier",
                "ImageClassifier",
                "ObjectDetector",
                "ImageSegmenter",
            ]
            if any(name in class_name for name in keras_hub_model_names):
                return True

        return False
The _is_keras_hub_model method seems overly complex. Since this method is monkey-patched onto the Task class, self will always be an instance of a Task subclass. The Task class is defined within the keras_hub package, so self.__class__.__module__ will always contain "keras_hub". This means the first check if "keras_hub" in module_name: will always evaluate to True, making the subsequent checks redundant.
The method can be simplified to just return True, or even be removed entirely by changing the call site in keras_hub_export to if format == "litert":. This would make the code simpler and more robust.
def _is_keras_hub_model(self):
    """Check if this model is a Keras-Hub model that needs special
    handling."""
    # This method is patched onto `keras_hub.src.models.task.Task`,
    # so any object `self` will be an instance of a Keras-Hub model.
    return True

Updated docstrings in base.py, configs.py, and litert.py to specify explicit type annotations for function arguments and return values. This improves code readability and helps developers understand expected input types for exporter configuration and usage.
Ensure exported model filepath ends with '.tflite' and update verbose message to reflect correct path. Remove unused 'mask' argument from model calls in LiteRTExporter adapters. Update test to use input dtypes from interpreter for test inputs.
Simplifies calls to keras_hub_model in TextModelAdapter and ImageModelAdapter by removing unnecessary line breaks and grouping arguments. Also updates string quotes for consistency and improves formatting in litert_test.py for readability.
@gemini-code-assist

@pctablet505 I can provide a new code review for the current state of the pull request, but I cannot directly check previous review comments or their resolution status. To initiate a new review, please use the

@gemini-code-assist review
Code Review
This pull request introduces a comprehensive model export functionality for liteRT, which is a significant and valuable addition to keras-hub. The implementation is well-structured, following a modular design with base classes, configurations, and a registry. The inclusion of extensive unit and integration tests, including numerical verification for production models, is commendable and ensures the reliability of the new feature.
I have provided a few suggestions aimed at improving maintainability by reducing code duplication and simplifying some of the logic. These are minor refactorings that can enhance the clarity and robustness of the codebase. Overall, this is an excellent contribution.
        if sequence_length is None:
            if hasattr(self.model, "preprocessor") and self.model.preprocessor:
                sequence_length = getattr(
                    self.model.preprocessor,
                    "sequence_length",
                    self.DEFAULT_SEQUENCE_LENGTH,
                )
            else:
                sequence_length = self.DEFAULT_SEQUENCE_LENGTH
There's some duplicated logic for inferring sequence_length across several text model configuration classes (CausalLMExporterConfig, TextClassifierExporterConfig, Seq2SeqLMExporterConfig, TextModelExporterConfig).
To improve maintainability and reduce code duplication, you could extract this logic into a shared helper function. For example:
def _infer_sequence_length(model, default_length):
    """Infer sequence length from model preprocessor or use default."""
    if hasattr(model, "preprocessor") and model.preprocessor:
        return getattr(
            model.preprocessor,
            "sequence_length",
            default_length,
        )
    return default_length


# Then, in each config class:
class CausalLMExporterConfig(KerasHubExporterConfig):
    # ...
    def get_input_signature(self, sequence_length=None):
        # ...
        if sequence_length is None:
            sequence_length = _infer_sequence_length(
                self.model, self.DEFAULT_SEQUENCE_LENGTH
            )
        return _get_text_input_signature(self.model, sequence_length)

This would make the code more DRY and easier to update in the future.
        dtype = "float32"
        if hasattr(self.model, "inputs") and self.model.inputs:
            model_dtype = self.model.inputs[0].dtype
            dtype = (
                model_dtype.name
                if hasattr(model_dtype, "name")
                else model_dtype
            )
The logic for inferring the input dtype is duplicated across the image model configuration classes (ImageClassifierExporterConfig, ObjectDetectorExporterConfig, ImageSegmenterExporterConfig).
To avoid repetition and improve maintainability, you could extract this into a private helper function, similar to _infer_image_size.
For example:
def _infer_image_dtype(model):
    """Infer image dtype from model inputs."""
    if hasattr(model, "inputs") and model.inputs:
        model_dtype = model.inputs[0].dtype
        return model_dtype.name if hasattr(model_dtype, "name") else model_dtype
    return "float32"


# Then, in each image config class:
class ImageClassifierExporterConfig(KerasHubExporterConfig):
    # ...
    def get_input_signature(self, image_size=None):
        # ...
        dtype = _infer_image_dtype(self.model)
        return {
            "images": keras.layers.InputSpec(
                shape=(None, *image_size, 3),
                dtype=dtype,
                name="images",
            ),
        }

This change would centralize the dtype inference logic.
        # For text models, use sequence_length; for image models, get image_size
        # from preprocessor
        if is_text_model:
            param = self.max_sequence_length
        elif is_image_model:
            # Get image_size from model's preprocessor
            if hasattr(self.model, "preprocessor") and hasattr(
                self.model.preprocessor, "image_size"
            ):
                param = self.model.preprocessor.image_size
            else:
                param = None  # Will use default in get_input_signature
        else:
            param = None
The logic to determine the param (which can be max_sequence_length for text models or image_size for image models) is duplicated. It appears once in the export method (lines 84-105) and again in _create_export_wrapper (lines 260-279).
To simplify the code and remove this redundancy, you can determine param once in the export method and then pass it as an argument to _create_export_wrapper.
For example:
# In export() method
...
param = ...  # logic to determine param
self._ensure_model_built(param)
input_signature = self.config.get_input_signature(param)
wrapped_model = self._create_export_wrapper(param)
...

# In _create_export_wrapper() method
def _create_export_wrapper(self, param):
    ...
    # No need to determine param again, just use it.
    return adapter_class(
        self.model,
        self.config.EXPECTED_INPUTS,
        self.config.get_input_signature(param),
    )

This refactoring would make the code cleaner and more maintainable.
        if hasattr(self, "__class__"):
            class_name = self.__class__.__name__
            module_name = self.__class__.__module__

            # Check if it's from keras_hub package
            if "keras_hub" in module_name:
                return True

            # Check if it has keras-hub specific attributes
            if hasattr(self, "preprocessor") and hasattr(self, "backbone"):
                return True

            # Check for common Keras-Hub model names
            keras_hub_model_names = [
                "CausalLM",
                "Seq2SeqLM",
                "TextClassifier",
                "ImageClassifier",
                "ObjectDetector",
                "ImageSegmenter",
            ]
            if any(name in class_name for name in keras_hub_model_names):
                return True

        return False
The check if hasattr(self, "__class__") is redundant, as all Python objects have a __class__ attribute. You can safely remove this if statement and un-indent the code inside it.
class_name = self.__class__.__name__
module_name = self.__class__.__module__

# Check if it's from keras_hub package
if "keras_hub" in module_name:
    return True

# Check if it has keras-hub specific attributes
if hasattr(self, "preprocessor") and hasattr(self, "backbone"):
    return True

# Check for common Keras-Hub model names
keras_hub_model_names = [
    "CausalLM",
    "Seq2SeqLM",
    "TextClassifier",
    "ImageClassifier",
    "ObjectDetector",
    "ImageSegmenter",
]
if any(name in class_name for name in keras_hub_model_names):
    return True

return False

# For list-like _DictWrapper (e.g., transformer_layers)
if hasattr(child, "_data") and isinstance(
    child._data, list
):
    # Create a clean list of the trackable items
    clean_list = []
    for item in child._data:
        if hasattr(item, "_trackable_children"):
            clean_list.append(item)
    if clean_list:
        clean_children[name] = clean_list
# For dict-like _DictWrapper
elif hasattr(child, "_data") and isinstance(
    child._data, dict
):
    clean_dict = {}
    for k, v in child._data.items():
        if hasattr(v, "_trackable_children"):
            clean_dict[k] = v
    if clean_dict:
        clean_children[name] = clean_dict
# Skip if we can't unwrap safely
The logic for unwrapping _DictWrapper objects can be made more concise and Pythonic by using list and dictionary comprehensions. This would improve readability.
# For list-like _DictWrapper (e.g., transformer_layers)
if hasattr(child, "_data") and isinstance(
    child._data, list
):
    # Create a clean list of the trackable items
    clean_list = [
        item
        for item in child._data
        if hasattr(item, "_trackable_children")
    ]
    if clean_list:
        clean_children[name] = clean_list
# For dict-like _DictWrapper
elif hasattr(child, "_data") and isinstance(
    child._data, dict
):
    clean_dict = {
        k: v
        for k, v in child._data.items()
        if hasattr(v, "_trackable_children")
    }
    if clean_dict:
        clean_children[name] = clean_dict
# Skip if we can't unwrap safely
Added model export support for Keras-Hub models.
This PR requires keras-team/keras#21674 (the export feature in Keras) as a prerequisite, and is built on top of it.
Simple Demo
Complete numeric verification tests multiple models for numerical agreement.
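The numeric check can be sketched as a simple tolerance comparison (illustrative NumPy helper, not the PR's test code; the real tests feed identical inputs to the Keras model and the exported LiteRT interpreter):

```python
import numpy as np


def verify_outputs(keras_out, litert_out, atol=1e-4, rtol=1e-4):
    """Return the max absolute difference and whether the two outputs
    agree within tolerance (illustrative helper)."""
    keras_out = np.asarray(keras_out, dtype=np.float64)
    litert_out = np.asarray(litert_out, dtype=np.float64)
    max_abs = float(np.max(np.abs(keras_out - litert_out)))
    agree = bool(np.allclose(keras_out, litert_out, atol=atol, rtol=rtol))
    return max_abs, agree


# Stand-in outputs; in the real tests these come from model.predict()
# and from running the exported .tflite model on the same inputs.
keras_logits = np.array([0.10, 0.70, 0.20])
litert_logits = keras_logits + 5e-5
max_abs, agree = verify_outputs(keras_logits, litert_logits)
print(max_abs, agree)
```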
Verified models: